import numpy as np
import torch
from torchvision import transforms
from transformers import CLIPVisionModel, CLIPVisionConfig
from scipy.io import loadmat
import pickle


def get_clip_features(images, *, encoder=None, transform=None,
                      target_device=None, batch_size=64, verbose=True):
    """Encode images with the CLIP vision model and stack the pooled features.

    Each image is passed through ``transform``, the results are grouped into
    mini-batches, and each batch is run through ``encoder`` under
    ``torch.no_grad()``. The batch's ``pooler_output`` is moved to the CPU,
    and all batches are stacked row-wise in input order.

    Args:
        images: Iterable of images accepted by ``transform`` (in this script,
            the per-sample arrays of the SVHN image tensor).
        encoder: Callable taking a batched tensor and returning an object with
            a ``pooler_output`` tensor. Defaults to the module-level ``model``.
        transform: Per-image callable returning a tensor. Defaults to the
            module-level ``preprocess``.
        target_device: Device each batch is moved to before encoding.
            Defaults to the module-level ``device``.
        batch_size: Number of images per forward pass (must be >= 1).
        verbose: If true, print a progress line after every batch.

    Returns:
        ``np.ndarray`` with one row of features per input image, in input order.

    Raises:
        ValueError: If ``images`` is empty or ``batch_size`` < 1.
    """
    import itertools

    # Fall back to the module globals so the original one-argument call
    # ``get_clip_features(images)`` keeps working unchanged.
    encoder = model if encoder is None else encoder
    transform = preprocess if transform is None else transform
    target_device = device if target_device is None else target_device
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    features_list = []
    processed = 0
    image_iter = iter(images)  # works for arrays, lists and one-shot iterators
    with torch.no_grad():
        while True:
            chunk = list(itertools.islice(image_iter, batch_size))
            if not chunk:
                break
            # Stacking per-image tensors is equivalent to the old per-image
            # unsqueeze(0), but lets the model process the whole batch at once.
            batch = torch.stack([transform(img) for img in chunk]).to(target_device)
            features_list.append(encoder(batch).pooler_output.cpu().numpy())
            processed += len(chunk)
            if verbose:
                print(f"{processed} images processed")

    if not features_list:
        raise ValueError("images is empty; nothing to encode")
    return np.vstack(features_list)


# Select the compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIP vision tower together with its pretrained weights.
# Constructing it from the config alone (CLIPVisionModel(config)) gives
# randomly initialised weights, which would make every extracted feature
# meaningless.
# NOTE(review): pooler_output of CLIPVisionModel is the pooled CLS state, not
# the projected joint-space embedding (image_embeds of
# CLIPVisionModelWithProjection) -- confirm which one downstream code expects.
model = CLIPVisionModel.from_pretrained("./clip-vit-base-patch32").to(device)
print("CLIP Loaded Successfully")

# Inference only: put the model in eval mode.
model.eval()

# Load the SVHN training split.
train_data = loadmat("data/svhn/train_32x32.mat")

# [h, w, c, num] -> [num, h, w, c], so iterating yields one image at a time.
train_images = np.transpose(train_data['X'], (3, 0, 1, 2))

# Reproduce CLIP's own preprocessing: bicubic resize to the 224x224 input
# size, then normalise with the channel mean/std CLIP was trained with.
# Without Normalize the model would see raw [0, 1] inputs, which are out of
# distribution for it.
# ToPILImage presumably receives uint8 HxWxC arrays here -- confirm for other data.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
preprocess = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])

# Extract one feature row per training image.
features = get_clip_features(train_images)

# Persist the stacked feature array for later use.
with open('image_embeddings_svhn_train.pkl', 'wb') as f:
    pickle.dump(features, f)
